%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%% Explanation of what the file does
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 
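% This function solves the baseline and the counterfactual (Chinese technology
% shock) equilibria of the model for given values of delta, ka, nu, etc., and
% summarizes the counterfactual-minus-baseline differences.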
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%% Inputs needed
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 
% 1. DNWR parameter (delta), scalar
% 2. Inverse migration elasticity (ka), scalar
% 3. Inverse sectoral elasticity (nu), scalar
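% 4. Migration indicator (mi), scalar: 0 or 1 read a fixed U.S. mobility
%    matrix from ProcessedData.xlsx (sheets MUFIXEDNM / MUFIXED), 2 uses an
%    identity matrix, 3 builds it from the initial U.S. allocation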
% 5. Time changes in Chinese technology (tectimetemp), size 7*1
% 6. Sector changes in Chinese technology (tecsecttemp), size 12*1 (it is
%    transposed to a 1*12 row before being combined with tectimetemp)
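% 7. DNWR option (optdelta), scalar: selects which cells of DeltaMat are set
%    to delta (1 all cells, 2 sectors 1-12, 3 U.S. regions, 4 U.S. regions
%    and sectors 1-12); all other cells use 0.01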
% 8. Other parameters passed to the solver blocks: ba, zrs, AECR, AECRS,
%    aeatw (ba is also the per-period discount factor in the welfare sum)
% 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%% Outputs produced
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% 
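% 1. Regression slopes of regional outcomes on the exposure measure (betabig)
% 2. Population-weighted and unweighted mean U.S. welfare change (MWCW, MWCU)
% 3. Changes in U.S. employment, labor supply, unemployment (CUSEmp, CUSLSu,
%    CUSUne)
% 4. Statistics on DNWR-constrained region-sectors (numrscon)
% 5. U.S. aggregate inflation path (USInfla)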

function [betabig,MWCW,MWCU,CUSEmp,CUSLSu,CUSUne,numrscon,USInfla] = ...
          NEST_OuterAlgoCF_FUNC(delta,ka,nu,mi,tectimetemp,...
          tecsecttemp,optdelta,ba,zrs,AECR,AECRS,aeatw)
% NEST_OuterAlgoCF_FUNC  Solve the baseline and the counterfactual economy
% and summarize the counterfactual-minus-baseline differences.
%
% Inputs
%   delta       - DNWR parameter written into DeltaMat (see optdelta).
%   ka, nu      - inverse migration / inverse sectoral elasticities; their
%                 minimum (capped at 0.6) is also the damping weight of the
%                 fixed-point updates.
%   mi          - how the U.S. block (regions 1..M) of the mobility matrix
%                 is built:
%                   0 -> sheet MUFIXEDNM, 1 -> sheet MUFIXED,
%                   2 -> identity matrix for ALL regions,
%                   3 -> every row equals the initial U.S. allocation
%                        across region-sector cells, normalized to sum to 1.
%                 Any other value raises an error.
%   tectimetemp - time profile of the technology shock (vector, row or
%                 column); its length is the number of shocked periods.
%   tecsecttemp - sector profile of the shock for sectors 1..12 (vector,
%                 row or column).
%   optdelta    - cells of DeltaMat set to delta (all others use 0.01):
%                   1 -> all regions and sectors, 2 -> sectors 1..12,
%                   3 -> regions 1..M, 4 -> regions 1..M and sectors 1..12.
%   ba, zrs, AECR, AECRS, aeatw - passed through to the solver blocks; ba
%                 is also the per-period discount factor in the welfare sum.
%
% Outputs
%   betabig  - 7x30 OLS slopes of regional changes on the rescaled regressor
%              read from sheet EXP (column 1 is left at zero).
%   MWCW     - population-weighted mean welfare change over regions 1..M (x100).
%   MWCU     - unweighted mean welfare change over regions 1..M (x100).
%   CUSEmp   - 1xT change in total employment of regions 1..M, in % of the
%              first-period baseline.
%   CUSLSu   - 1xT change in total labor supply, same normalization.
%   CUSUne   - 1xT change in the aggregate unemployment rate (pp).
%   numrscon - 4xT statistics on DNWR-constrained cells (counts, shares).
%   USInfla  - 1xT population-weighted average of (Pagg-1)*100 over
%              regions 1..M.
%
% Requires the workbooks under Inputs/ and the NEBA_Block*/NECO_Block*
% functions on the path.

% Fail fast on an unsupported migration option: otherwise the problem only
% surfaces after the workbooks are read, as an undefined-variable error.
if ~isscalar(mi) || ~any(mi == [0,1,2,3])
    error('NEST_OuterAlgoCF_FUNC:invalidMi', ...
          'mi must be one of 0, 1, 2 or 3 (got %s).', mat2str(mi));
end

% Section 1: Basic variables to be defined

S             = 14;     % number of sectors
T             = 100;    % number of periods
I             = 87;     % number of regions (U.S. regions plus other countries)
M             = 50;     % regions 1..M are the U.S. regions
sigma         = 6;
lecs          = length(tectimetemp);

% Import data

dataimp       = xlsread('Inputs/ProcessedData.xlsx','BIFIXED','C2:CK1219');
datarea       = reshape(dataimp',[I,I,S]);
labshares     = xlsread('Inputs/InputData.xlsx','VA','C2:P88');
inoutcdp      = xlsread('Inputs/InputData.xlsx','IO','D2:Q533');
inouttemp1    = permute(reshape(inoutcdp',[S,S,I-M+1]),[2,1,3]);
inouttemp2    = NaN(S,S,I);
% The first IO table is shared by regions 1..M; the remaining tables belong
% one-to-one to regions M+1..I.
inouttemp2(:,:,1:M) = repmat(inouttemp1(:,:,1),1,1,M);
inouttemp2(:,:,M+1:end) = inouttemp1(:,:,2:end);
inouttemp3    = permute(inouttemp2,[3,1,2]);
inoutmat      = inouttemp3 .* repmat(permute(1-labshares,[1,3,2]),1,S,1);

% Define trade shares, labor income, deficits, final cons. shares, etc

tradesharedat = datarea./repmat(sum(datarea,1),I,1,1);
expbysec      = squeeze(sum(datarea,1));
revbysec      = squeeze(sum(datarea,2));
laborincsec   = labshares .* revbysec;
deficits      = sum(expbysec,2)-sum(revbysec,2);
alphasinter   = repmat(permute(revbysec,[1,3,2]),1,S,1);
alphasregnum  = expbysec - sum(inoutmat .* alphasinter,3);
Amat          = alphasregnum./repmat(sum(alphasregnum,2),1,S);
Smat          = ones(1,S)*sigma;

% DNWR bounds: delta overrides the 0.01 default on the cells selected by
% optdelta (any other optdelta leaves 0.01 everywhere).
DeltaMat      = 0.01*ones(I,S);
if optdelta == 1
    DeltaMat  = delta*ones(I,S);
elseif optdelta == 2
    DeltaMat(:,1:12) = delta*ones(I,12);
elseif optdelta == 3
    DeltaMat(1:M,:) = delta*ones(M,S);
elseif optdelta == 4
    DeltaMat(1:M,1:12) = delta*ones(M,12);
end

% Shocks

DefEra        = repmat(deficits,1,T);
TECmatC       = ones(I,S,T);
TAUmatC       = ones(I,I,S,T);

% Generate matrix of mobility for all countries (irrelevant for USA)

Splus         = S+1;
EllMinus      = xlsread('Inputs/ProcessedData.xlsx','L1999MIX','B2:P88');
EllInit       = xlsread('Inputs/InputData.xlsx','L2000CENSUS','B2:P88');
EllResh       = reshape(EllInit',I*Splus,1)';
EllDen        = repmat(sum(EllMinus,2),1,Splus)';
EllDiv        = EllResh./EllDen(:)';
EllDivE       = repmat(EllDiv,I*Splus,1);
MuInter       = kron(eye(I),ones(Splus));
MuInit        = MuInter.*EllDivE;

% Assign the U.S. block (regions 1..M) of the mobility matrix; mi was
% validated above, so exactly one case applies.

switch mi
    case 1
        MuInit(1:M*Splus,1:M*Splus) = ...
            xlsread('Inputs/ProcessedData.xlsx','MUFIXED','D2:ABY751');
    case 0
        MuInit(1:M*Splus,1:M*Splus) = ...
            xlsread('Inputs/ProcessedData.xlsx','MUFIXEDNM','D2:ABY751');
    case 3
        % Every row of the U.S. block is the same vector: the initial U.S.
        % allocation across region-sector cells, normalized to sum to one.
        UEllInit  = EllInit(1:M,:);
        UEllResh  = reshape(UEllInit',M*Splus,1)';
        UEllDiv   = UEllResh./sum(UEllInit(:));
        MuInit(1:M*Splus,1:M*Splus) = repmat(UEllDiv,M*Splus,1);
    case 2
        % Identity for every region (not just the U.S.); presumably the
        % "no mobility" case.
        MuInit    = eye(I*Splus);
end
Ugu           = ones(I,Splus,T);

% Section 2: Solve for the baseline economy

% Damping weight of the fixed-point updates. NOTE(review): if nu or ka is
% <= 0 then lamdo <= 0, Ugu is never moved towards UNew, and the loop
% presumably runs until maxiter -- confirm such inputs are not used.
lamdo         = min(min(nu,ka),0.6);
iterations    = 0;
distancebg    = 10;
maxiter       = 1500;
tolerance     = 3e-3;

% Run the contraction mapping loop part of the tatonnement algorithm 
% described in the paper

while distancebg>tolerance && iterations<maxiter
    [MLBas,lLBas] = NEBA_Block1(Ugu,MuInit,EllInit,ba,ka,nu);
    
    [WDBas,LDBas,PDBas,~,TLBas,YLBas] = NEBA_Block2(DefEra,...
    laborincsec,revbysec,tradesharedat,TAUmatC,TECmatC,labshares,...
    inoutmat,Smat,Amat,lLBas,DeltaMat,0.2,1e-4,AECR,AECRS,aeatw,M);

    UNew = NEBA_Block3(PDBas,Amat,WDBas,LDBas,lLBas,MLBas,ba,ka,nu,Ugu,zrs);
    
    iterations = iterations+1;
    distancebg = max(abs(UNew(:)-Ugu(:)));
    Ugu = lamdo * UNew + (1-lamdo)*Ugu;
end
% ~(d <= tol) also catches a NaN distance, which ends the loop silently.
if ~(distancebg <= tolerance)
    warning('NEST_OuterAlgoCF_FUNC:baselineNotConverged', ...
            ['Baseline loop stopped after %d iterations with distance ' ...
             '%g (tolerance %g).'], iterations, distancebg, tolerance);
end

% Section 3: Solve for the counterfactual economy

iterations    = 0;
distancebg    = 10;
Ugu2          = ones(I,Splus,T);
TECN          = ones(I,S,T);
% Force a column (periods) times a row (sectors 1..12) so either input
% orientation yields the lecs-by-12 outer product.
tectime       = tectimetemp(:);
tecsect       = tecsecttemp(:)';
tecvalpre     = (1+tectime*tecsect).^(sigma-1);
% Only region 57 is shocked (per the file header presumably China -- confirm).
TECN(57,1:12,1:lecs) = permute(tecvalpre,[3,2,1]);
TecHat        = TECN./TECmatC;

% Run the contraction mapping loop part of the tatonnement algorithm 
% described in the paper

while distancebg>tolerance && iterations<maxiter
    [MLCfa,lLCfa,TZ] = NECO_Block1(Ugu2,MLBas,EllInit,ba,ka,nu);
    
    [WHCfa,LHCfa,PHCfa,~,~,YLCfa] = NECO_Block2(DefEra,...
    laborincsec,revbysec,tradesharedat,TAUmatC,TecHat,labshares,...
    inoutmat,Smat,Amat,lLCfa,DeltaMat,TLBas,WDBas,LDBas,0.3,1e-4,...
    AECR,AECRS,aeatw,M);

    [UNew2,Delrshat] = NECO_Block3(PHCfa,Amat,WHCfa,LHCfa,lLCfa,MLCfa,ba,ka,...
    nu,MLBas,TZ,lLBas,Ugu2,zrs,LDBas);
    
    iterations = iterations+1;
    distancebg = max(abs(UNew2(:)-Ugu2(:)));
    Ugu2 = lamdo * UNew2 + (1-lamdo) * Ugu2; 
end
if ~(distancebg <= tolerance)
    warning('NEST_OuterAlgoCF_FUNC:counterfactualNotConverged', ...
            ['Counterfactual loop stopped after %d iterations with ' ...
             'distance %g (tolerance %g).'], iterations, distancebg, tolerance);
end

Pagg               = repmat(prod(PHCfa .^ repmat(Amat,1,1,T-1),2),1,S,1);
EllRel             = lLCfa(:,2:end,:);
EllBR              = lLBas(:,2:end,:);
EllDot             = EllRel(:,:,2:end)./EllRel(:,:,1:end-1);
EllDBas            = EllBR(:,:,2:end)./EllBR(:,:,1:end-1);
EllHat             = EllDot./EllDBas;
Omega              = ones(I,Splus,T-1);
Omega(:,2:end,:)   = (WHCfa.*LHCfa)./(Pagg.*EllHat.*Delrshat);

% Compute mu dots in baseline

MuHashB            = NaN(I*Splus,I,T);
MuCondB            = NaN(size(MLBas));
MuHashC            = NaN(I*Splus,I,T);
MuCondC            = NaN(size(MLCfa));
for t=1:T
    MuHashB(:,:,t) = MLBas(:,:,t) * kron(eye(I),ones(Splus,1));
    MuCondB(:,:,t) = MLBas(:,:,t) ./ kron(MuHashB(:,:,t),ones(1,Splus));
    MuHashC(:,:,t) = MLCfa(:,:,t) * kron(eye(I),ones(Splus,1));
    MuCondC(:,:,t) = MLCfa(:,:,t) ./ kron(MuHashC(:,:,t),ones(1,Splus));
end
MuCondB(isnan(MuCondB)) = 0;
MuCondC(isnan(MuCondC)) = 0;
MuHashDB           = MuHashB(:,:,2:end)./MuHashB(:,:,1:end-1);
MuHashDB(isnan(MuHashDB)) = 1;
MuCondDB           = MuCondB(:,:,2:end)./MuCondB(:,:,1:end-1);
MuCondDB(isnan(MuCondDB)) = 1;
MuHashDC           = MuHashC(:,:,2:end)./MuHashC(:,:,1:end-1);
MuHashDC(isnan(MuHashDC)) = 1;
MuCondDC           = MuCondC(:,:,2:end)./MuCondC(:,:,1:end-1);
MuCondDC(isnan(MuCondDC)) = 1;

% Compute mu hats

MuCondHat          = MuCondDC./MuCondDB;
MuCondT1           = NaN(I*Splus,1,T-1);
for counter=1:I*Splus
    MuCondT1(counter,1,:) = MuCondHat(counter,counter,:);
end
MuCondT2           = permute(reshape(MuCondT1,Splus,I,T-1),[2,1,3]);

MuHashHat          = MuHashDC./MuHashDB;
MuHashT1           = NaN(I*Splus,1,T-1);
for counter=1:I*Splus
    MuHashT1(counter,1,:) = MuHashHat(counter,ceil(counter/Splus),:);
end
MuHashT2           = permute(reshape(MuHashT1,Splus,I,T-1),[2,1,3]);

% Finally compute welfare change (discounted sum over periods 1..T-1)

WELF1              = log(Omega./( MuCondT2.^nu .* MuHashT2.^ka));
WELF2              = linspace(1,T-1,T-1)';
WELF3              = permute(WELF2,[2,3,1]);
WELF4              = repmat(WELF3,I,Splus,1);
WELF5              = ba.^WELF4 .* WELF1;
WELF               = sum(WELF5,3);

clearvars Mu* MLBas MLCfa

% Section 4: Compute stats

LDCfa              = LHCfa .* LDBas;
LLBas              = lLBas;
LLBas(:,2:end,2:end) = repmat(lLBas(:,2:end,1),1,1,T-1).*cumprod(LDBas,3);
LLCfa              = lLCfa;
LLCfa(:,2:end,2:end) = repmat(lLCfa(:,2:end,1),1,1,T-1).*cumprod(LDCfa,3);
bartikjose         = xlsread('Inputs/InputData.xlsx','EXP','B2:B51');
barr               = bartikjose*2.63/mean(bartikjose);

% Compute baseline things

PopulationB        = squeeze(sum(lLBas(1:M,:,:),2));
TotalLaborSupplyB  = squeeze(sum(lLBas(1:M,2:end,:),2));
TotalEmploymentB   = squeeze(sum(LLBas(1:M,2:end,:),2));
MoPB               = squeeze(sum(LLBas(1:M,2:13,:),2))./PopulationB;
NoPB               = squeeze(sum(LLBas(1:M,14:15,:),2))./PopulationB;
UoPB               = (TotalLaborSupplyB-TotalEmploymentB)./PopulationB;
loPB               = TotalLaborSupplyB./PopulationB;
WagebillMB         = squeeze(sum(YLBas(1:M,1:12,:),2));
LaborMB            = squeeze(sum(LLBas(1:M,2:13,:),2));
WagesMB            = WagebillMB./LaborMB;
WagebillNB         = squeeze(sum(YLBas(1:M,13:end,:),2));
LaborNB            = squeeze(sum(LLBas(1:M,14:end,:),2));
WagesNB            = WagebillNB./LaborNB;

% Compute counterfactual things

PopulationC        = squeeze(sum(lLCfa(1:M,:,:),2));
TotalLaborSupplyC  = squeeze(sum(lLCfa(1:M,2:end,:),2));
TotalEmploymentC   = squeeze(sum(LLCfa(1:M,2:end,:),2));
MoPC               = squeeze(sum(LLCfa(1:M,2:13,:),2))./PopulationC;
NoPC               = squeeze(sum(LLCfa(1:M,14:15,:),2))./PopulationC;
UoPC               = (TotalLaborSupplyC-TotalEmploymentC)./PopulationC;
loPC               = TotalLaborSupplyC./PopulationC;
WagebillMC         = squeeze(sum(YLCfa(1:M,1:12,:),2));
LaborMC            = squeeze(sum(LLCfa(1:M,2:13,:),2));
WagesMC            = WagebillMC./LaborMC;
WagebillNC         = squeeze(sum(YLCfa(1:M,13:end,:),2));
LaborNC            = squeeze(sum(LLCfa(1:M,14:end,:),2));
WagesNC            = WagebillNC./LaborNC;

% For horizons 2..30, regress the counterfactual-minus-baseline changes
% (shares averaged over tin-1..tin+1), rescaled by 10/(tin-1), on a constant
% and barr; keep the slope on barr.
betabig            = zeros(7,30);
X                  = [ones(M,1),barr];   % same regressors at every horizon

for tin=2:30
    MoPchan        = (mean(MoPC(:,tin-1:tin+1),2)-mean(MoPB(:,tin-1:tin+1),2))*100;
    NoPchan        = (mean(NoPC(:,tin-1:tin+1),2)-mean(NoPB(:,tin-1:tin+1),2))*100;
    UoPchan        = (mean(UoPC(:,tin-1:tin+1),2)-mean(UoPB(:,tin-1:tin+1),2))*100;
    loPchan        = (mean(loPC(:,tin-1:tin+1),2)-mean(loPB(:,tin-1:tin+1),2))*100;
    PoPchan        = (PopulationC(:,tin)-PopulationB(:,tin))./PopulationB(:,1)*100;
    MWchan         = ((WagesMC(:,tin)-WagesMB(:,tin))./WagesMB(:,1))*100;
    NWchan         = ((WagesNC(:,tin)-WagesNB(:,tin))./WagesNB(:,1))*100;
    Y              = [UoPchan,-loPchan,PoPchan,MoPchan,NoPchan,MWchan,NWchan]*(10/(tin-1));
    % Least squares via QR instead of explicitly inverting X'*X.
    coeff          = X\Y;
    betabig(:,tin) = coeff(2,:)';
end

% Welfare calculations (weights use period-2 baseline allocations)

lShaReg            = lLBas(:,:,2)./repmat(sum(lLBas(:,:,2),2),1,Splus);
AggWelf            = sum(WELF .* lShaReg,2)*100;
AggWelfU           = AggWelf(1:M);
PopSta             = sum(lLBas(1:M,:,2),2);
PopSha             = PopSta./sum(PopSta);
MWCW               = PopSha'*AggWelfU;
MWCU               = mean(AggWelfU);

% Total US employment and labor supply

USEmpC             = sum(TotalEmploymentC);
USEmpB             = sum(TotalEmploymentB);
USLSuC             = sum(TotalLaborSupplyC);
USLSuB             = sum(TotalLaborSupplyB);
USUneC             = (1-USEmpC./USLSuC)*100;
USUneB             = (1-USEmpB./USLSuB)*100;
CUSEmp             = ((USEmpC-USEmpB)/USEmpB(1))*100;
CUSLSu             = ((USLSuC-USLSuB)/USLSuB(1))*100;
CUSUne             = USUneC-USUneB;

% Sectors that are constrained (wage change below DeltaMat + 1e-5)

WDCfa              = WHCfa .* WDBas;
bM                 = squeeze(sum(sum(WDCfa(1:M, 1:12,:)<repmat(DeltaMat(1:M, 1:12),1,1,T-1)+0.00001)));
bN                 = squeeze(sum(sum(WDCfa(1:M,13:S,:)<repmat(DeltaMat(1:M,13:S),1,1,T-1)+0.00001)));
RelEmploy          = lLCfa(1:M,2:end,2:end);
indices            = WDCfa(1:M, 1:end,:)<repmat(DeltaMat(1:M, 1:end),1,1,T-1)+0.00001;
sM                 = squeeze(sum(sum(RelEmploy.*indices))./sum(sum(RelEmploy))*100);
% NOTE(review): the numerator counts constrained employment in ALL sectors,
% while the denominator is employment in sectors 1..12 only. If a share for
% one sector group was intended, restrict both to the same sectors -- confirm.
sN                 = squeeze(sum(sum(RelEmploy.*indices))./sum(sum(RelEmploy(:,1:12,:)))*100);
numrscon           = [0,bM';0,bN';0,sM';0,sN'];

% U.S. aggregate inflation (weights: regional population relative to the
% period-1 U.S. total)

Pinf               = (squeeze(Pagg(1:M,1,:))-1)*100;
Pweights           = squeeze(sum(lLCfa(1:M,:,:),2)/sum(sum(lLCfa(1:M,:,1),2)));
USInfla            = [0,sum(Pinf.*Pweights(:,1:end-1))];

end
